import os
import logging
import time
import glob

import numpy as np
from tqdm import tqdm
import torch
import torch.utils.data as data



from isaacgymenvs.ddim.models.diffusion_controlseq import Model, ModelSimpleMLP, ModelCond, ModelInvDyn, ModelDiffResMLP, ModelInvDynObjState, ModelInvDynTokenMimicking, ModelInvDynMaskedCond, ModelInvDynMaskedObjMotionCond, WorldModel, WorldModelDeltaActions, ModelInvDynObjMotionPred, QValueModel, VModel, ModelInvDynFingerPos, InverseDynamicsModel
from isaacgymenvs.ddim.models.ema import EMAHelper
from isaacgymenvs.ddim.functions import get_optimizer
from isaacgymenvs.ddim.functions.losses import loss_registry
from isaacgymenvs.ddim.datasets import get_dataset, data_transform, inverse_data_transform
from isaacgymenvs.ddim.functions.ckpt_util import get_ckpt_path

from isaacgymenvs.ddim.models.running_mean_std import RunningMeanStd

from isaacgymenvs.ddim.datasets.controlseq import ControlSeq, ControlSeqStochastic, ControlSeqStochasticTokenMimicing, ControlSeqWorldModel

import torchvision.utils as tvu

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import torch.nn.functional as F

import pytorch_kinematics as pk

try:
    from torchvision.ops import box_convert
    from torchvision.transforms.functional import quaternion_from_matrix
except:
    pass

def torch2hwcuint8(x, clip=False):
    """Affinely rescale ``x`` so that -1 maps to 0 and 1 maps to 1.

    Despite the name, this function neither permutes axes nor casts to
    uint8; it only applies ``(x + 1) / 2``.

    Args:
        x: Input tensor.
        clip: When true, clamp ``x`` to [-1, 1] before rescaling.

    Returns:
        The rescaled tensor.
    """
    rescale_source = torch.clamp(x, -1, 1) if clip else x
    return (rescale_source + 1.0) / 2.0

def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """Build the per-timestep beta array for a named schedule.

    Args:
        beta_schedule: One of ``"quad"``, ``"linear"``, ``"const"``,
            ``"jsd"`` or ``"sigmoid"``.
        beta_start: First beta value (ignored by ``"const"`` and ``"jsd"``).
        beta_end: Last beta value (ignored by ``"jsd"``).
        num_diffusion_timesteps: Number of entries in the returned array.

    Returns:
        A 1-D numpy array of shape ``(num_diffusion_timesteps,)``.

    Raises:
        NotImplementedError: If ``beta_schedule`` is not a known name.
    """
    n_steps = num_diffusion_timesteps

    if beta_schedule == "quad":
        # Interpolate linearly in sqrt-space, then square.
        roots = np.linspace(
            beta_start ** 0.5, beta_end ** 0.5, n_steps, dtype=np.float64
        )
        betas = roots ** 2
    elif beta_schedule == "linear":
        betas = np.linspace(beta_start, beta_end, n_steps, dtype=np.float64)
    elif beta_schedule == "const":
        betas = beta_end * np.ones(n_steps, dtype=np.float64)
    elif beta_schedule == "jsd":
        # 1/T, 1/(T-1), 1/(T-2), ..., 1
        denominators = np.linspace(n_steps, 1, n_steps, dtype=np.float64)
        betas = 1.0 / denominators
    elif beta_schedule == "sigmoid":
        # Logistic curve sampled on [-6, 6], rescaled to [beta_start, beta_end].
        logits = np.linspace(-6, 6, n_steps)
        squashed = 1 / (np.exp(-logits) + 1)
        betas = squashed * (beta_end - beta_start) + beta_start
    else:
        raise NotImplementedError(beta_schedule)

    assert betas.shape == (n_steps,)
    return betas


class Diffusion(object):
    def __init__(self, args, config, device=None):
        """Set up the diffusion schedule and cache inverse-dynamics options.

        Args:
            args: Run arguments; stored as ``self.args``.
            config: Configuration object. ``config.model``,
                ``config.diffusion`` and ``config.invdyn`` are read here.
            device: Torch device for the schedule tensors and normalizers.
                Defaults to CUDA when available, otherwise CPU.
        """
        self.args = args
        self.config = config
        if device is None:
            device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )
        self.device = device

        self.model_var_type = config.model.var_type
        betas = get_beta_schedule(
            beta_schedule=config.diffusion.beta_schedule,
            beta_start=config.diffusion.beta_start,
            beta_end=config.diffusion.beta_end,
            num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
        )
        betas = self.betas = torch.from_numpy(betas).float().to(self.device)
        self.num_timesteps = betas.shape[0]

        alphas = 1.0 - betas
        alphas_cumprod = alphas.cumprod(dim=0)
        alphas_cumprod_prev = torch.cat(
            [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0
        )
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        # NOTE: self.logvar is only defined for the two variance types below;
        # any other value of model_var_type leaves the attribute unset.
        if self.model_var_type == "fixedlarge":
            self.logvar = betas.log()
        elif self.model_var_type == "fixedsmall":
            self.logvar = posterior_variance.clamp(min=1e-20).log()

        self.invdyn_model_arch = config.invdyn.model_arch
        self.history_length = config.invdyn.history_length
        self.future_length = config.invdyn.future_length

        self.normalize_input = config.invdyn.normalize_input
        self.normalize_output = config.invdyn.normalize_output

        self.obj_state_predictor = config.invdyn.obj_state_predictor

        # Best-effort setup of the extrinsics predictor: configs without a
        # `pred_extrin` entry (or any failure while building the normalizers)
        # fall back to pred_extrin=False.  Only `Exception` is caught so that
        # KeyboardInterrupt / SystemExit still propagate.
        # The constant 32 is presumably the per-step feature width -- TODO confirm.
        try:
            self.pred_extrin = self.config.invdyn.pred_extrin
            self.extrin_history_length = 30
            self.sa_mean_std = RunningMeanStd((self.extrin_history_length, 32)).to(self.device)
            self.sa_mean_std.train()

            self.running_mean_std = RunningMeanStd(self.history_length * (32)).to(self.device)
            self.running_mean_std.train()
        except Exception:
            if self.normalize_input:
                self.running_mean_std = RunningMeanStd(self.history_length * (32)).to(self.device)
                self.running_mean_std.train()
            self.pred_extrin = False

        # Cache the remaining inverse-dynamics switches as instance attributes.
        self.invdyn_masked_cond = self.config.invdyn.masked_cond
        self.invdyn_masked_obj_motion_cond = self.config.invdyn.invdyn_masked_obj_motion_cond
        self.invdyn_model_train_tokenizer = self.config.invdyn.train_tokenizer
        self.invdyn_model_resume_tokenizer = self.config.invdyn.resume_tokenizer

        self.invdyn_model_optimize_via_fingertip_pos = self.config.invdyn.optimize_via_fingertip_pos
        self.invdyn_finger_idx = self.config.invdyn.finger_idx
        self.invdyn_joint_idx = self.config.invdyn.joint_idx
        self.invdyn_train_obj_motion_pred_model = self.config.invdyn.train_obj_motion_pred_model

        self.invdyn_model_fingertip_rot_coef = self.config.invdyn.fingertip_rot_coef

        self.invdyn_model_train_finger_pos_tracking_model = self.config.invdyn.train_finger_pos_tracking_model

        self.q_value_model_expectile_regression = self.config.invdyn.q_value_model_expectile_regression
        self.q_value_model_w_v_model = self.config.invdyn.q_value_model_w_v_model
        self.invdyn_model_train_value_network = self.config.invdyn.train_value_network
        self.invdyn_hist_context_length = self.config.invdyn.hist_context_length

        self.invdyn_train_residual_wm = self.config.invdyn.train_residual_wm
        self.invdyn_prev_wm_ckpt = self.config.invdyn.prev_wm_ckpt
        self.invdyn_w_hand_root_ornt = self.config.invdyn.w_hand_root_ornt

        self.wm_as_invdyn_prediction = self.config.invdyn.wm_as_invdyn_prediction

        self.stack_wm_history = self.config.invdyn.stack_wm_history




    
    def build_pk_chain_finger(self, finger_idx):
        """Load the hand URDF into a pytorch_kinematics chain.

        Sets ``self.chain``, ``self.isaac_order_to_pk_order`` and
        ``self.fingertip_names`` for later forward-kinematics calls.  The URDF
        path is relative, so it is resolved against the current working
        directory.

        Args:
            finger_idx: Not used by this method; kept so existing callers
                continue to work.
        """
        leap_urdf_path = "../RL/assets/leap_hand/leap_hand_right.urdf"
        # Read through a context manager so the file handle is always closed.
        with open(leap_urdf_path) as urdf_file:
            urdf_text = urdf_file.read()

        chain = pk.build_chain_from_urdf(urdf_text)
        chain = chain.to(dtype=torch.float32, device=self.device)
        self.chain = chain

        # Column permutation applied to a 16-wide joint vector before it is
        # handed to the chain: isaac columns 0-3, then 8-15, then 4-7.
        pk_order = list(range(4)) + list(range(8, 16)) + [4, 5, 6, 7]
        self.isaac_order_to_pk_order = torch.tensor(pk_order, dtype=torch.long, device=self.device)
        self.fingertip_names = [
            'index_tip_head', 'thumb_tip_head', 'middle_tip_head', 'ring_tip_head'
        ]
    
    def forward_pk_chain_for_finger_pos(self, joint_angles, finger_idx):
        
        tot_joint_angles = torch.zeros((joint_angles.shape[0], 16), dtype=torch.float32, device=self.device)
        finger_joint_idxes = [ _ for _ in range(finger_idx  * 4 , (finger_idx + 1) * 4) ]
        finger_joint_idxes = torch.tensor(finger_joint_idxes, dtype=torch.long, device=self.device)
        tot_joint_angles[:, finger_joint_idxes] = joint_angles.clone()
        
        pk_joint_angles = tot_joint_angles[:, self.isaac_order_to_pk_order]
        tg_batch = self.chain.forward_kinematics(pk_joint_angles)
        finger_trans_matrix = tg_batch[self.fingertip_names[finger_idx]].get_matrix()[:, :3, 3]
        
        finger_rot_matrix = tg_batch[self.fingertip_names[finger_idx]].get_matrix()[:, :3, :3]
        
        return finger_trans_matrix, finger_rot_matrix
        

    def train_world_model_delta_actions(self):
        """Train ``WorldModelDeltaActions`` against a frozen real world model.

        For each batch the delta-action model predicts a correction that is
        added to the trailing action dimensions.  The corrected action is fed
        through the real world model (loaded from
        ``args.real_world_model_path``, kept in eval mode and not part of the
        optimizer), and the squared error against the dataset's ``nex_state``
        is minimised.  When ``optimize_via_fingertip_pos`` is set, the error
        is measured on fingertip positions obtained through forward
        kinematics instead.

        Checkpoints are written to ``<log_path>/ckpt.pth`` at step 1, every
        ``config.training.snapshot_freq`` steps, and at the end of each epoch.
        """
        args, config = self.args, self.config
        tb_logger = self.config.tb_logger
        dataset, test_dataset = get_dataset(args, config)
        train_loader = data.DataLoader(
            dataset,
            batch_size=config.training.batch_size,
            shuffle=True,
            num_workers=config.data.num_workers,
        )

        # Real world model: weights come from the first checkpoint entry.
        model = WorldModel(self.config)
        model = model.to(self.device)
        real_states = torch.load(args.real_world_model_path, map_location=self.config.device)
        model.load_state_dict(real_states[0], strict=True)
        model.eval()

        # Sim world model: loaded (and validated via strict=True) but its
        # prediction is currently replaced by the dataset's `nex_state` below.
        sim_model = WorldModel(self.config)
        sim_model = sim_model.to(self.device)
        sim_states = torch.load(args.sim_world_model_path, map_location=self.config.device)
        sim_model.load_state_dict(sim_states[0], strict=True)
        sim_model.eval()

        delta_action_model = WorldModelDeltaActions(self.config)
        delta_action_model = delta_action_model.to(self.device)

        if self.invdyn_model_optimize_via_fingertip_pos:
            self.build_pk_chain_finger(self.invdyn_finger_idx)

        optimizer = get_optimizer(self.config, delta_action_model.parameters())

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(delta_action_model)
        else:
            ema_helper = None

        def checkpoint_states(epoch, step):
            """Collect the list of objects stored in a checkpoint file."""
            states = [
                delta_action_model.state_dict(),
                optimizer.state_dict(),
                epoch,
                step,
            ]
            if self.normalize_input:
                states.append(self.running_mean_std.state_dict())
            if self.config.model.ema:
                states.append(ema_helper.state_dict())
            return states

        ckpt_path = os.path.join(self.args.log_path, "ckpt.pth")

        start_epoch, step = 0, 0
        logging_step_interval = 20000

        for epoch in range(start_epoch, self.config.training.n_epochs):
            data_start = time.time()
            data_time = 0

            loss_name_to_ep_loss = {'loss': [], 'task_loss': [], 'loss_wocompensator': []}

            for i, data_batch in enumerate(tqdm(train_loader)):
                data_time += time.time() - data_start

                state = data_batch['state'].to(self.device)
                nex_state = data_batch['nex_state'].to(self.device)
                actions = data_batch['action'].to(self.device)

                input_dict = {
                    key: data_batch[key].to(self.device) for key in data_batch
                }

                # Target transition: the recorded next state stands in for the
                # sim world model's prediction.
                pred_nex_state_sim = nex_state.clone()

                # Baseline error of the real world model without compensation
                # (logged only, no gradients).
                with torch.no_grad():
                    detached_input_dict = {
                        key: input_dict[key].detach() for key in input_dict
                    }
                    pred_nex_state_wo_compensator = model(detached_input_dict)
                    diff_acts_wo_compensator = torch.sum(
                        (pred_nex_state_wo_compensator - pred_nex_state_sim) ** 2, dim=-1
                    )
                    loss_name_to_ep_loss['loss_wocompensator'].append(
                        diff_acts_wo_compensator.mean().item()
                    )

                # Add the predicted delta to the trailing action dimensions and
                # keep the leading dimensions unchanged.
                pred_delta_actions = delta_action_model(input_dict)
                delta_dim = pred_delta_actions.size(-1)
                compensate_actions = torch.cat(
                    [actions[..., :-delta_dim], actions[..., -delta_dim:] + pred_delta_actions],
                    dim=-1,
                )

                input_dict_real = {
                    key: data_batch[key].to(self.device) for key in data_batch
                }
                input_dict_real.update(
                    {
                        'state': state,
                        'nex_state': nex_state,
                        'action': compensate_actions,
                    }
                )

                pred_nex_state = model(input_dict_real)

                if self.invdyn_model_optimize_via_fingertip_pos:
                    unscaled_pred_nex_state = train_loader.dataset.unscale_states(pred_nex_state)
                    unscaled_pred_nex_state_sim = train_loader.dataset.unscale_states(pred_nex_state_sim)
                    pred_nex_finger_pos, _ = self.forward_pk_chain_for_finger_pos(
                        unscaled_pred_nex_state, self.invdyn_finger_idx
                    )
                    pred_nex_finger_pos_sim, _ = self.forward_pk_chain_for_finger_pos(
                        unscaled_pred_nex_state_sim, self.invdyn_finger_idx
                    )
                    diff_acts = torch.sum(
                        (pred_nex_finger_pos - pred_nex_finger_pos_sim) ** 2, dim=-1
                    ).mean()
                else:
                    diff_acts = torch.sum((pred_nex_state - pred_nex_state_sim) ** 2, dim=-1).mean()

                loss = diff_acts

                loss_name_to_ep_loss['loss'].append(loss.detach().item())
                loss_name_to_ep_loss['task_loss'].append(diff_acts.detach().item())

                step += 1

                optimizer.zero_grad()
                loss.backward()

                # Gradient clipping is best-effort: failures (e.g. a config
                # without `optim.grad_clip`) are deliberately ignored.
                try:
                    torch.nn.utils.clip_grad_norm_(
                        delta_action_model.parameters(), config.optim.grad_clip
                    )
                except Exception:
                    pass
                optimizer.step()

                if self.config.model.ema:
                    ema_helper.update(delta_action_model)

                if step % logging_step_interval == 0:
                    avg_ep_loss = sum(loss_name_to_ep_loss['loss']) / float(len(loss_name_to_ep_loss['loss']))
                    # No extrinsics loss is tracked in this loop; the field is
                    # kept so the log line format stays unchanged.
                    avg_ep_extrin = 0.0

                    tb_logger.add_scalar("loss", avg_ep_loss, global_step=step)

                    logging_info_str = f"step: {step}, loss: {avg_ep_loss}, "
                    for key in loss_name_to_ep_loss:
                        if len(loss_name_to_ep_loss[key]) > 0:
                            avg_ep_loss = sum(loss_name_to_ep_loss[key]) / float(len(loss_name_to_ep_loss[key]))
                            logging_info_str += f"{key}: {avg_ep_loss}, "
                    logging_info_str += f"avg_ep_extrin: {avg_ep_extrin}, data time: {data_time / (i+1)}"
                    logging.info(logging_info_str)

                if step % self.config.training.snapshot_freq == 0 or step == 1:
                    torch.save(checkpoint_states(epoch, step), ckpt_path)

                data_start = time.time()

            # Persist the end-of-epoch state so that updates made after the
            # last periodic snapshot are not lost.
            torch.save(checkpoint_states(epoch, step), ckpt_path)
        
    

    def _data_transform(self, x):
        """Apply the dataset transform selected by ``self.config`` to ``x``."""
        return data_transform(self.config, x)
    
    
    

    def init_models(self, ckpt_fn=None):
        """Build the model selected by the config and load its weights.

        Sets ``self.model`` (in eval mode) and ``self.ema_helper``.

        Args:
            ckpt_fn: Optional path to a file holding a bare model state dict.
                When ``None``, a training checkpoint list is read from
                ``<log_path>/ckpt.pth`` (or ``ckpt_<ckpt_id>.pth`` when
                ``config.sampling.ckpt_id`` is set); its first entry is the
                model state dict and, if EMA is enabled, its last entry is
                the EMA state.

        Raises:
            NotImplementedError: For an unknown ``config.model.subtype`` when
                ``args.model_type`` is not ``'invdyn'``.
        """

        def strip_module_prefix(state_dict):
            """Remove a leading ``module.`` from every key that has one."""
            prefix = 'module.'
            if not any(key.startswith(prefix) for key in state_dict):
                # Return the original object untouched so nothing attached to
                # it is lost.
                return state_dict
            return {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }

        if self.args.model_type == 'invdyn':
            if self.config.invdyn.train_finger_pos_tracking_model:
                model = ModelInvDynFingerPos(self.config)
            elif self.config.invdyn.token_mimicking:
                model = ModelInvDynTokenMimicking(self.config)
            elif self.invdyn_masked_cond:
                model = ModelInvDynMaskedCond(self.config)
            elif self.config.invdyn.obj_state_predictor:
                model = ModelInvDynObjState(self.config)
            else:
                model = ModelInvDyn(self.config)
        else:
            if self.config.model.subtype == "unet":
                if self.config.model.cond:
                    model = ModelCond(self.config)
                else:
                    model = Model(self.config)
            elif self.config.model.subtype == "mlp":
                model = ModelSimpleMLP(self.config)
            elif self.config.model.subtype == "resmlp":
                model = ModelDiffResMLP(self.config)
            else:
                raise NotImplementedError

        # Bound on every path so the assignment to self.ema_helper below can
        # never hit an unbound local (previously the ckpt_fn path did).
        ema_helper = None

        if ckpt_fn is None:
            if getattr(self.config.sampling, "ckpt_id", None) is None:
                print(f"log_path: {self.args.log_path}")
                ckpt_path = os.path.join(self.args.log_path, "ckpt.pth")
            else:
                ckpt_path = os.path.join(
                    self.args.log_path, f"ckpt_{self.config.sampling.ckpt_id}.pth"
                )
            states = torch.load(ckpt_path, map_location=self.config.device)
            model = model.to(self.device)

            model.load_state_dict(strip_module_prefix(states[0]), strict=True)

            # Only interpret the last checkpoint entry as EMA weights when EMA
            # is enabled; otherwise it may be the step counter or a
            # normalizer state.
            if self.config.model.ema:
                ema_helper = EMAHelper(mu=self.config.model.ema_rate)
                ema_helper.register(model)
                ema_helper.load_state_dict(strip_module_prefix(states[-1]))
                ema_helper.ema(model)
        else:
            states = torch.load(ckpt_fn, map_location=self.config.device)
            model = model.to(self.device)
            model.load_state_dict(states, strict=True)

        model.eval()
        self.model = model
        self.ema_helper = ema_helper

        if self.pred_extrin:
            # NOTE(review): entries 4 and 5 are where the training loops in
            # this file store the two normalizers when pred_extrin is on.  A
            # bare state dict loaded through `ckpt_fn` has no such entries, so
            # this lookup presumably fails on that path -- confirm with callers.
            print(f"Initializing extrins")
            self.sa_mean_std.load_state_dict(states[4])
            self.running_mean_std.load_state_dict(states[5])
            self.sa_mean_std.eval()
            self.running_mean_std.eval()
            
    
    
    
    
    def forward_states_for_actions(self, history_states, future_ref=None, history_extrin=None, hist_context=None):
        
        if self.pred_extrin: # forward states for actions #
            # history_extrin = history_extrin.to(self.device)
            history_extrin = self.sa_mean_std(history_extrin.detach())
            # history_extrin = history_extrin.to(self.device)
            history_states = self.running_mean_std(history_states.detach())
        
        if self.invdyn_masked_cond:
            pred_actions = self.model(history_states, history_extrin=history_extrin)
        elif self.config.invdyn.obj_state_predictor:
            pred_actions = self.model(history_states, history_extrin=history_extrin)
        else:
            pred_actions = self.model(history_states, future_ref, history_extrin=history_extrin, hist_context=hist_context)
        return pred_actions




def setup_ddp(rank, world_size):
    """Join an NCCL process group that rendezvouses on ``localhost:12355``.

    Overwrites ``MASTER_ADDR`` and ``MASTER_PORT`` in the environment, then
    initialises the process group and selects CUDA device ``rank``.

    Args:
        rank: Rank of this process; also used as the CUDA device index.
        world_size: Total number of processes in the group.
    """
    rendezvous = {"MASTER_ADDR": "localhost", "MASTER_PORT": "12355"}
    os.environ.update(rendezvous)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    
    
    
def setup_ddp_flexible():
    """Initialise NCCL from launcher-provided environment variables.

    Reads ``LOCAL_RANK`` from the environment to pick the CUDA device, joins
    the default process group, and prints the resulting rank.

    Returns:
        The process rank modulo the number of visible CUDA devices.
    """
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group("nccl")
    global_rank = dist.get_rank()
    print(f"Start running basic DDP example on rank {global_rank}.")
    return global_rank % torch.cuda.device_count()





def train_invdyn_ddp(self):
    """Train an inverse-dynamics model with DistributedDataParallel.

    ``self`` is expected to provide the attributes set up by
    ``Diffusion.__init__`` (this function is presumably attached to that
    class elsewhere -- TODO confirm).

    The process group is initialised through ``setup_ddp_flexible``.  The
    model variant, the batch fields used as input / target, and the loss are
    all selected from ``self.args`` and the ``invdyn`` config switches.  The
    process whose local device index is 0 logs and writes checkpoints to
    ``<log_path>/ckpt.pth`` (periodically and after each epoch) and
    ``<log_path>/ckpt_ep<epoch>.pth`` (after each epoch).  The process group
    is destroyed when training finishes.
    """
    device_id = setup_ddp_flexible()
    # Local CUDA device index: used for tensor placement and rank-0 gating.
    rank = device_id

    # Ask the process group for its real size and this process's global rank
    # instead of assuming exactly eight processes on one node.
    world_size = dist.get_world_size()
    sampler_rank = dist.get_rank()

    self.config.device = device_id

    args, config = self.args, self.config
    tb_logger = self.config.tb_logger
    dataset, test_dataset = get_dataset(args, config)

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=sampler_rank)

    train_loader = data.DataLoader(
        dataset,
        batch_size=config.training.batch_size,
        num_workers=config.data.num_workers,
        sampler=sampler,
    )

    if self.invdyn_masked_obj_motion_cond:
        model = ModelInvDynMaskedObjMotionCond(self.config)
    elif self.invdyn_masked_cond:
        model = ModelInvDynMaskedCond(self.config)
    elif self.obj_state_predictor:
        model = ModelInvDynObjState(self.config)
    elif self.invdyn_model_train_finger_pos_tracking_model:
        model = ModelInvDynFingerPos(self.config)
    else:
        model = ModelInvDyn(config)

    model = model.to(rank)
    model = DDP(model, device_ids=[rank])

    # The normalizers only exist when __init__ created them; move whichever
    # ones are present onto this process's device.
    if hasattr(self, 'sa_mean_std'):
        self.sa_mean_std = self.sa_mean_std.to(rank)
    if hasattr(self, 'running_mean_std'):
        self.running_mean_std = self.running_mean_std.to(rank)

    optimizer = get_optimizer(self.config, model.parameters())

    if self.config.model.ema:
        ema_helper = EMAHelper(mu=self.config.model.ema_rate)
        ema_helper.register(model)
    else:
        ema_helper = None

    def checkpoint_states(epoch, step):
        """Collect the list of objects stored in a checkpoint file."""
        states = [
            model.module.state_dict(),
            optimizer.state_dict(),
            epoch,
            step,
        ]
        if self.pred_extrin:
            states.append(self.sa_mean_std.state_dict())
            states.append(self.running_mean_std.state_dict())
        elif self.normalize_input:
            states.append(self.running_mean_std.state_dict())
        if self.config.model.ema:
            states.append(ema_helper.state_dict())
        return states

    start_epoch, step = 0, 0
    logging_step_interval = 20000

    for epoch in range(start_epoch, self.config.training.n_epochs):
        # Without this call DistributedSampler reuses the same ordering in
        # every epoch.
        sampler.set_epoch(epoch)

        data_start = time.time()
        data_time = 0

        ep_reg_sigma = []
        ep_extrin = []

        loss_name_to_ep_loss = {'loss': [], 'task_loss': []}

        for i, data_batch in enumerate(tqdm(train_loader)):
            state = data_batch['state']
            actions = data_batch['action']

            # Choose the conditioning signal (`motion`) for this run.
            if self.args.train_world_model_via_invdyn:
                motion = data_batch['action']
                actions = data_batch['future_hand_qpos']
            elif self.args.future_type == 'hand_motion':
                motion = data_batch['motion']
            elif self.args.future_type == 'obj_motion':
                motion = data_batch['obj_euler_diff']
            elif self.args.future_type == 'hand_obj_motion':
                hand_motion = data_batch['motion']
                obj_motion = data_batch['obj_euler_diff']
                motion = torch.cat(
                    [hand_motion, obj_motion], dim=-1
                )
            elif self.args.future_type == 'obj_rot_dir':
                motion = data_batch['obj_rot_dir']
            else:
                raise NotImplementedError

            # Value-network training overrides all three tensors.
            if self.invdyn_model_train_value_network:
                state = data_batch['value_net_hist_info']
                motion = data_batch['value_net_nex_action']
                actions = data_batch['value_net_nex_value']

            data_time += time.time() - data_start
            model.train()

            state = state.to(rank)
            motion = motion.to(rank)
            actions = actions.to(rank)

            if self.pred_extrin:
                history_extrin = data_batch['history_extrin']
                history_extrin = history_extrin.to(rank)
                history_extrin = self.sa_mean_std(history_extrin.detach())
                gt_extrin = data_batch['gt_extrin']
                gt_extrin = gt_extrin.to(rank)
                state = self.running_mean_std(state.detach())
            else:
                if self.normalize_input:
                    state = self.running_mean_std(state.detach())
                history_extrin = None

            # Forward pass; the call signature depends on the model variant.
            if self.invdyn_model_train_finger_pos_tracking_model:
                finger_pos_w_motion_ref = data_batch['finger_pos_w_motion_ref'].to(rank)
                actions = data_batch['target_finger_qtars'].to(rank)
                pred_actions = model(state, finger_pos_w_motion_ref)
            elif self.invdyn_masked_obj_motion_cond:
                input_dict = {
                    key: data_batch[key].to(rank) for key in data_batch
                }
                pred_actions = model(input_dict)
            elif self.invdyn_masked_cond:
                input_dict = {
                    'history_obs': state,
                    'obj_motion_ref': data_batch['obj_euler_diff'].to(rank),
                    'hand_motion_ref': data_batch['motion'].to(rank),
                    'hand_cond_mask': data_batch['hand_cond_mask'].to(rank),
                    'obj_cond_mask': data_batch['obj_cond_mask'].to(rank),
                }
                pred_actions = model(input_dict)
            elif self.obj_state_predictor:
                pred_actions = model(state, history_extrin=history_extrin)
            elif self.invdyn_hist_context_length > 0:
                hist_context = data_batch['hist_context'].to(rank)
                pred_actions = model(state, motion, history_extrin=history_extrin, hist_context=hist_context)
            elif self.invdyn_w_hand_root_ornt:
                hand_root_ornt = data_batch['hand_root_ornt'].to(rank)
                state = torch.cat(
                    [state, hand_root_ornt], dim=-1
                )
                pred_actions = model(state, motion, history_extrin=history_extrin)
            else:
                pred_actions = model(state, motion, history_extrin=history_extrin)

            if self.invdyn_model_train_tokenizer:
                # The model returns a dict of named loss terms; sum them.
                loss = 0
                for key in pred_actions:
                    loss += pred_actions[key]
                for key in pred_actions:
                    if key not in loss_name_to_ep_loss:
                        loss_name_to_ep_loss[key] = []
                    loss_name_to_ep_loss[key].append(pred_actions[key].detach().item())
                loss_name_to_ep_loss['loss'].append(loss.detach().item())
            else:
                diff_acts = torch.sum((pred_actions - actions) ** 2, dim=-1)
                diff_acts = diff_acts.mean()

                if self.invdyn_model_arch == 'resmlp_gaussian':
                    # Attributes of the wrapped network live on `.module`;
                    # the DDP wrapper does not forward them.
                    reg_sigma = model.module.reg_sigma
                    reg_sigma_coef = 1e-4
                    loss = diff_acts + reg_sigma * reg_sigma_coef
                    # Record here so the value is only read on the path that
                    # actually produced it.
                    ep_reg_sigma.append(reg_sigma.detach().item())
                else:
                    loss = diff_acts

                loss_name_to_ep_loss['loss'].append(loss.detach().item())
                loss_name_to_ep_loss['task_loss'].append(diff_acts.detach().item())

            if self.pred_extrin:
                diff_extrin = torch.sum((model.module.extrin_pred - gt_extrin) ** 2, dim=-1)
                diff_extrin = diff_extrin.mean()
                loss = loss + diff_extrin
                ep_extrin.append(diff_extrin.detach().item())

            step += 1

            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping is best-effort: failures (e.g. a config
            # without `optim.grad_clip`) are deliberately ignored.
            try:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), config.optim.grad_clip
                )
            except Exception:
                pass
            optimizer.step()

            if self.config.model.ema:
                ema_helper.update(model)

            ### Logging ###
            if step % logging_step_interval == 0 and rank == 0:
                avg_ep_loss = sum(loss_name_to_ep_loss['loss']) / float(len(loss_name_to_ep_loss['loss']))
                if len(ep_reg_sigma) == 0:
                    avg_ep_reg_sigma = 0.0
                else:
                    avg_ep_reg_sigma = sum(ep_reg_sigma) / float(len(ep_reg_sigma))

                if len(ep_extrin) == 0:
                    avg_ep_extrin = 0.0
                else:
                    avg_ep_extrin = sum(ep_extrin) / float(len(ep_extrin))

                tb_logger.add_scalar("loss", avg_ep_loss, global_step=step)

                logging_info_str = f"step: {step}, loss: {avg_ep_loss}, "
                for key in loss_name_to_ep_loss:
                    if len(loss_name_to_ep_loss[key]) > 0:
                        avg_ep_loss = sum(loss_name_to_ep_loss[key]) / float(len(loss_name_to_ep_loss[key]))
                        logging_info_str += f"{key}: {avg_ep_loss}, "
                logging_info_str += f"reg_sigma: {avg_ep_reg_sigma}, avg_ep_extrin: {avg_ep_extrin}, data time: {data_time / (i+1)}"

                logging.info(logging_info_str)

            ### Periodic snapshot ###
            if (step % self.config.training.snapshot_freq == 0 or step == 1) and rank == 0:
                torch.save(checkpoint_states(epoch, step), os.path.join(self.args.log_path, "ckpt.pth"))

            data_start = time.time()

        ### End-of-epoch snapshots: one per-epoch file plus the rolling file ###
        if rank == 0:
            states = checkpoint_states(epoch, step)
            torch.save(states, os.path.join(self.args.log_path, f"ckpt_ep{epoch}.pth"))
            torch.save(states, os.path.join(self.args.log_path, "ckpt.pth"))

    dist.destroy_process_group()




def train_world_model_delta_actions_ddp(self):
    """Train a ``WorldModelDeltaActions`` compensator under DistributedDataParallel.

    For each batch the compensator predicts a delta action, which is clamped to
    [-1, 1], multiplied by ``config.invdyn.delta_action_scale`` and added to the
    recorded action.  The compensated action is fed to a world model that is
    not part of the optimizer (``self.wm_model``, the per-joint models in
    ``self.joint_idx_to_wm``, or a checkpoint loaded from
    ``args.real_world_model_path``), and the squared error between that
    prediction and the batch's ``nex_state`` is minimised with respect to the
    compensator parameters only.

    Side effects:
        * mutates ``self.config.invdyn`` and several attributes on ``self``;
        * on local rank 0 writes a TensorBoard scalar, a per-joint loss plot,
          ``joint_idx_to_ep_avg_loss.npy`` and ``ckpt.pth`` into
          ``self.args.log_path``;
        * destroys the default process group before returning.
    """
    device_id = setup_ddp_flexible()
    # The returned id is used both as the CUDA device index and for the
    # "rank == 0" logging / checkpoint guards below.
    rank = device_id

    # Shard the dataset according to the actual process group.  A hard-coded
    # replica count silently drops data whenever the job is not launched with
    # exactly that many processes.  Fall back to the historical values when no
    # process group is available.
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        sampler_rank = dist.get_rank()
    else:
        world_size = 8
        sampler_rank = rank

    delta_action_scale = self.config.invdyn.delta_action_scale

    self.device = rank
    self.config.device = device_id

    args, config = self.args, self.config
    tb_logger = self.config.tb_logger
    dataset, _test_dataset = get_dataset(args, config)

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=sampler_rank)

    train_loader = data.DataLoader(
        dataset,
        batch_size=config.training.batch_size,
        num_workers=config.data.num_workers,
        sampler=sampler,
    )

    # Hard-wired experiment switches.
    self.wm_w_neighbouring = False
    self.per_joint_wm_for_fullhand_compensator = True
    self.add_nearing_finger = False

    # Remember the configured history length; it is temporarily overridden
    # while the frozen world model is constructed and restored afterwards.
    tmp_wm_history_length = self.config.invdyn.wm_history_length + 0

    self.multi_joint_single_wm = self.config.invdyn.multi_joint_single_wm
    self.multi_finger_single_wm = self.config.invdyn.multi_finger_single_wm
    self.single_hand_wm = self.config.invdyn.single_hand_wm
    self.fullhand_wobjstate_wm = self.config.invdyn.fullhand_wobjstate_wm

    use_shared_wm = bool(
        self.multi_joint_single_wm
        or self.multi_finger_single_wm
        or self.single_hand_wm
        or self.fullhand_wobjstate_wm
    )

    if use_shared_wm:
        self.hand_dof_lower = dataset.hand_dof_lower.to(rank)
        self.hand_dof_upper = dataset.hand_dof_upper.to(rank)

        self.config.invdyn.finger_idx = -1
        self.config.invdyn.joint_idx = -1
        self.config.invdyn.wm_history_length = 10

        self.config.invdyn.add_nearing_neighbour = self.wm_w_neighbouring
        self.config.invdyn.add_nearing_finger = self.add_nearing_finger

        # NOTE(review): the checkpoint path is an empty placeholder; it must be
        # filled in before this branch can load anything.
        wm_ckpt_fn = ''

        if self.config.invdyn.wm_hist_length_in_delta_action > 0:
            self.config.invdyn.wm_history_length = self.config.invdyn.wm_hist_length_in_delta_action
            self.wm_hist_length_in_delta_action = self.config.invdyn.wm_hist_length_in_delta_action + 0
        else:
            self.wm_hist_length_in_delta_action = tmp_wm_history_length

        self.wm_model = WorldModel(self.config).to(rank)

        # Read the checkpoint once; entry 0 holds the model weights and the
        # last entry the EMA state.
        wm_ckpt = torch.load(wm_ckpt_fn)
        self.wm_model.load_state_dict(wm_ckpt[0])

        wm_ema_helper = EMAHelper(mu=self.config.model.ema_rate)
        wm_ema_helper.register(self.wm_model)
        wm_ema_helper.load_state_dict(wm_ckpt[-1])
        wm_ema_helper.ema(self.wm_model)

        self.wm_model.eval()

        self.config.invdyn.wm_history_length = tmp_wm_history_length
    else:
        # NOTE(review): ``model`` only exists on this path, yet it is the model
        # used by the non-per-joint branch of the training loop.
        model = WorldModel(self.config)
        model = model.to(rank)

        real_states = torch.load(args.real_world_model_path)
        model.load_state_dict(real_states[0], strict=True)
        model.eval()

    if self.per_joint_wm_for_fullhand_compensator:
        self.config.invdyn.finger_idx = -1
        self.config.invdyn.joint_idx = -1
        self.config.invdyn.wm_history_length = 10
        self.config.invdyn.hist_context_length = 0

    delta_action_model = WorldModelDeltaActions(self.config)
    delta_action_model = delta_action_model.to(self.device)
    delta_action_model = DDP(delta_action_model, device_ids=[rank])

    # NOTE(review): neither imported name is used in this function; the
    # statements are kept because they extend ``sys.path`` for the process.
    import sys
    sys.path.append('../RL')

    from hora.algo.models.models import ActorCritic
    from hora.algo.models.running_mean_std import RunningMeanStd

    optimizer = get_optimizer(self.config, delta_action_model.parameters())

    if self.config.model.ema:
        ema_helper = EMAHelper(mu=self.config.model.ema_rate)
        ema_helper.register(delta_action_model)
    else:
        ema_helper = None

    start_epoch, step = 0, 0

    logging_step_interval = 1000

    # Number of hand joints assumed by the reshapes and the per-joint plots.
    nn_dofs = 16

    # Running sum of each joint's loss and the resulting running-mean curve.
    # Keeping the sum avoids re-adding the entire history on every step.
    joint_idx_to_loss_sum = {}
    joint_idx_to_ep_avg_loss = {}

    for epoch in range(start_epoch, self.config.training.n_epochs):
        # Without this the sampler reuses the same permutation every epoch.
        sampler.set_epoch(epoch)

        data_start = time.time()
        data_time = 0

        ep_extrin = []

        loss_name_to_ep_loss = {'loss': [], 'task_loss': [], 'loss_wocompensator': [], 'loss_compensated_policy': []}

        i = 0

        for data_batch in tqdm(train_loader):
            state = data_batch['state']
            actions = data_batch['action']
            nex_state = data_batch['nex_state']

            data_time += time.time() - data_start

            state = state.to(self.device)
            nex_state = nex_state.to(self.device)
            actions = actions.to(self.device)

            input_dict = {
                key: data_batch[key].to(self.device) for key in data_batch
            }

            # Regression target: the next state stored in the batch.
            pred_nex_state_sim = nex_state.clone()

            pred_delta_actions = delta_action_model(input_dict)
            pred_delta_actions = torch.clamp(pred_delta_actions, -1.0, 1.0) * delta_action_scale

            if self.per_joint_wm_for_fullhand_compensator:
                if self.stack_wm_history:
                    unflatten_actions = actions[:, -1]
                    unflatten_actions = unflatten_actions.view(unflatten_actions.size(0), -1, nn_dofs)

                    # Add the delta in the dataset's scaled space, then map back.
                    unflatten_actions = dataset._scale(unflatten_actions, self.hand_dof_lower, self.hand_dof_upper).float()
                    compensated_actions = unflatten_actions + pred_delta_actions

                    compensated_actions = dataset._unscale(compensated_actions, self.hand_dof_lower, self.hand_dof_upper).float()
                    stacked_compensated_actions = compensated_actions

                    state = state[:, -1]
                else:
                    unflatten_actions = actions.view(actions.size(0), -1, nn_dofs)
                    last_unflatten_actions = unflatten_actions[:, -1, :]

                    last_unflatten_actions = dataset._scale(last_unflatten_actions, self.hand_dof_lower, self.hand_dof_upper).float()

                    compensated_actions = last_unflatten_actions + pred_delta_actions
                    compensated_actions = dataset._unscale(compensated_actions, self.hand_dof_lower, self.hand_dof_upper).float()

                    # Only the most recent action in the history is compensated.
                    stacked_compensated_actions = torch.cat(
                        [unflatten_actions[:, :-1], compensated_actions.unsqueeze(1)], dim=1
                    )

                unflatten_states = state.view(state.size(0), -1, nn_dofs)

                if use_shared_wm:
                    if self.fullhand_wobjstate_wm:
                        unflatten_states = unflatten_states[:, -self.wm_hist_length_in_delta_action:]
                        stacked_compensated_actions = stacked_compensated_actions[:, -self.wm_hist_length_in_delta_action:]
                        wm_states = unflatten_states.contiguous().view(unflatten_states.size(0), -1).contiguous()
                        wm_actions = stacked_compensated_actions.contiguous().view(stacked_compensated_actions.size(0), -1).contiguous()
                        cur_obj_state = data_batch['cur_obj_pose_state'].to(self.device)
                        wm_states = torch.cat([wm_states, cur_obj_state], dim=-1)
                        wm_input_dict = {
                            'state': wm_states,
                            'action': wm_actions
                        }
                    else:
                        wm_input_dict = {
                            'state': unflatten_states,
                            'action': stacked_compensated_actions
                        }

                    tot_pred_nex_state = self.wm_model(wm_input_dict)
                else:
                    # One world model per joint; predictions are concatenated
                    # along the last dimension.
                    tot_pred_nex_state = []
                    for joint_idx in self.wm_pred_joint_idxes:
                        if self.add_nearing_finger:
                            # Feed the whole 4-joint group that contains this joint.
                            cur_finger_idx = joint_idx // 4
                            cur_states = unflatten_states[:, :, cur_finger_idx * 4: (cur_finger_idx + 1) * 4]
                            cur_actions = stacked_compensated_actions[:, :, cur_finger_idx * 4: (cur_finger_idx + 1) * 4]
                            cur_joint_states = cur_states.contiguous().view(cur_states.size(0), -1)
                            cur_joint_actions = cur_actions.contiguous().view(cur_actions.size(0), -1)
                        else:
                            cur_joint_states = unflatten_states[:, :, joint_idx]
                            cur_joint_actions = stacked_compensated_actions[:, :, joint_idx]
                            if self.wm_w_neighbouring:
                                # Neighbours are clamped to the joint itself at
                                # the ends of each 4-joint group.
                                bf_joint_idx = joint_idx if joint_idx % 4 == 0 else joint_idx - 1
                                af_joint_idx = joint_idx if (joint_idx + 1) % 4 == 0 else joint_idx + 1

                                cur_bf_joint_state = unflatten_states[:, -1, [bf_joint_idx]]
                                cur_af_joint_state = unflatten_states[:, -1, [af_joint_idx]]
                                cur_bf_joint_action = compensated_actions[:, [bf_joint_idx]]
                                cur_af_joint_action = compensated_actions[:, [af_joint_idx]]

                                cur_joint_states = torch.cat(
                                    [cur_joint_states, cur_bf_joint_state, cur_af_joint_state], dim=-1
                                )
                                cur_joint_actions = torch.cat(
                                    [cur_joint_actions, cur_bf_joint_action, cur_af_joint_action], dim=-1
                                )
                        cur_joint_input_dict = {
                            'state': cur_joint_states,
                            'action': cur_joint_actions,
                        }
                        cur_joint_pred_nex_state = self.joint_idx_to_wm[joint_idx](cur_joint_input_dict)
                        tot_pred_nex_state.append(cur_joint_pred_nex_state)
                    tot_pred_nex_state = torch.cat(tot_pred_nex_state, dim=-1)

                # Per-joint diagnostics (batch-mean squared error per joint).
                per_joint_loss = torch.mean((tot_pred_nex_state - pred_nex_state_sim) ** 2, dim=0).detach().cpu()
                for i_j in range(nn_dofs):
                    cur_joint_loss = per_joint_loss[i_j].item()
                    joint_idx_to_loss_sum[i_j] = joint_idx_to_loss_sum.get(i_j, 0.0) + cur_joint_loss

                    avg_curve = joint_idx_to_ep_avg_loss.setdefault(i_j, [])
                    avg_curve.append(joint_idx_to_loss_sum[i_j] / float(len(avg_curve) + 1))

                diff_acts = torch.sum((tot_pred_nex_state - pred_nex_state_sim) ** 2, dim=-1)
                diff_acts = diff_acts.mean()
            else:
                # Add the delta to the trailing action dimensions only.
                compensate_actions = actions[..., -pred_delta_actions.size(-1):] + pred_delta_actions
                compensate_actions = torch.cat(
                    [actions[..., : -pred_delta_actions.size(-1)], compensate_actions], dim=-1
                )
                input_dict_real = {
                    key: data_batch[key].to(self.device) for key in data_batch
                }
                input_dict_real.update(
                    {
                        'state': state,
                        'nex_state': nex_state,
                        'action': compensate_actions
                    }
                )

                pred_nex_state = model(input_dict_real)

                with torch.no_grad():
                    # Gradient-free evaluation on the detached compensated
                    # action (already present under 'action').  The previous
                    # code overrode it with a name that is never defined.
                    # NOTE(review): confirm this is the intended input.
                    detached_input_dict_real = {
                        key: input_dict_real[key].detach() for key in input_dict_real
                    }
                    pred_nex_state_compensated = model(detached_input_dict_real)
                    diff_acts_compensated = torch.sum((pred_nex_state_compensated - pred_nex_state_sim) ** 2, dim=-1)
                    diff_acts_compensated = diff_acts_compensated.mean().item()
                    loss_name_to_ep_loss['loss_compensated_policy'].append(diff_acts_compensated)

                if self.invdyn_model_optimize_via_fingertip_pos:
                    # Compare fingertip positions obtained from the unscaled
                    # joint states instead of the joint states themselves.
                    unscaled_pred_nex_state = train_loader.dataset.unscale_states(pred_nex_state)
                    unscaled_pred_nex_state_sim = train_loader.dataset.unscale_states(pred_nex_state_sim)
                    pred_nex_finger_pos, _pred_rot = self.forward_pk_chain_for_finger_pos(unscaled_pred_nex_state, self.invdyn_finger_idx)
                    pred_nex_finger_pos_sim, _pred_rot_sim = self.forward_pk_chain_for_finger_pos(unscaled_pred_nex_state_sim, self.invdyn_finger_idx)
                    diff_pred_fingerpos_w_sim = torch.sum(
                        (pred_nex_finger_pos - pred_nex_finger_pos_sim) ** 2, dim=-1
                    )

                    diff_acts = diff_pred_fingerpos_w_sim.mean()
                else:
                    diff_acts = torch.sum((pred_nex_state - pred_nex_state_sim) ** 2, dim=-1)
                    diff_acts = diff_acts.mean()

            loss = diff_acts

            loss_name_to_ep_loss['loss'].append(loss.detach().item())
            loss_name_to_ep_loss['task_loss'].append(diff_acts.detach().item())

            step += 1

            optimizer.zero_grad()
            loss.backward()

            # Best-effort clipping: a failure here (e.g. a config without
            # ``optim.grad_clip``) must not abort training.
            try:
                torch.nn.utils.clip_grad_norm_(
                    delta_action_model.parameters(), config.optim.grad_clip
                )
            except Exception:
                pass
            optimizer.step()

            if self.config.model.ema:
                ema_helper.update(delta_action_model)

            if (step % logging_step_interval == 0) and rank == 0:
                avg_ep_loss = sum(loss_name_to_ep_loss['loss']) / float(len(loss_name_to_ep_loss['loss']))

                if len(ep_extrin) == 0:
                    avg_ep_extrin = 0.0
                else:
                    avg_ep_extrin = sum(ep_extrin) / float(len(ep_extrin))

                tb_logger.add_scalar("loss", avg_ep_loss, global_step=step)

                logging_info_str = f"step: {step}, loss: {avg_ep_loss}, "
                for key in loss_name_to_ep_loss:
                    if len(loss_name_to_ep_loss[key]) > 0:
                        avg_ep_loss = sum(loss_name_to_ep_loss[key]) / float(len(loss_name_to_ep_loss[key]))
                        logging_info_str += f"{key}: {avg_ep_loss}, "
                logging_info_str += f"avg_ep_extrin: {avg_ep_extrin}, data time: {data_time / (i+1)}"
                logging.info(logging_info_str)

                # Plot one running-mean loss curve per joint.  The backend is
                # selected before pyplot is imported so no display is needed.
                import matplotlib
                matplotlib.use('Agg')
                import matplotlib.pyplot as plt

                fig, axes = plt.subplots(4, 4, figsize=(20, 16))
                fig.suptitle('World Model Prediction Delta Absolute Values by Joint', fontsize=16)

                for joint_idx in range(nn_dofs):
                    ax = axes[joint_idx // 4, joint_idx % 4]

                    if joint_idx in joint_idx_to_ep_avg_loss and len(joint_idx_to_ep_avg_loss[joint_idx]) > 0:
                        data_plot = joint_idx_to_ep_avg_loss[joint_idx]
                        ax.plot(data_plot, linewidth=1, alpha=0.8)
                        ax.set_title(f'Joint {joint_idx}')
                        ax.set_xlabel('Steps')
                        ax.set_ylabel('Delta Abs')
                        ax.grid(True, alpha=0.3)
                    else:
                        ax.text(0.5, 0.5, f'No data for Joint {joint_idx}',
                                ha='center', va='center', transform=ax.transAxes)
                        ax.set_title(f'Joint {joint_idx}')

                plt.tight_layout()
                plot_file_path = os.path.join(self.args.log_path, 'wm_pred_delta_abs_plots.png')
                plt.savefig(plot_file_path, dpi=300, bbox_inches='tight')
                plt.close(fig)

                joint_idx_to_avg_loss_sv_fn = os.path.join(self.args.log_path, 'joint_idx_to_ep_avg_loss.npy')
                np.save(joint_idx_to_avg_loss_sv_fn, joint_idx_to_ep_avg_loss)

            if (step % self.config.training.snapshot_freq == 0 or step == 1) and rank == 0:
                # Checkpoint layout: [model, optimizer, epoch, step,
                # (input normaliser), (EMA state)].
                states = [
                    delta_action_model.module.state_dict(),
                    optimizer.state_dict(),
                    epoch,
                    step,
                ]

                if self.normalize_input:
                    states.append(self.running_mean_std.state_dict())

                if self.config.model.ema:
                    states.append(ema_helper.state_dict())

                torch.save(states, os.path.join(self.args.log_path, "ckpt.pth"))

            data_start = time.time()

            i += 1

    # Tear down the process group created by setup_ddp_flexible().
    dist.destroy_process_group()


import yaml
from isaacgymenvs.ddim.main import dict2namespace




def train_world_model_ddp(self):
    """Train a ``WorldModel`` to predict the next state from a data batch.

    The model is optimised with the summed squared error between its output and
    the batch's ``nex_state``.  When ``self.invdyn_train_residual_wm`` is set, a
    previously trained world model is loaded and the target becomes the
    residual ``nex_state - prev_model(batch)``.

    Every ``config.training.logging_step_interval`` steps local rank 0
    evaluates on the eval split (and on the separate test file when
    ``config.invdyn.use_sepearate_test_data`` is set), keeps the best model in
    ``ckpt_eval_best.pth`` and logs the running losses.  ``ckpt.pth`` is written
    by local rank 0 at step 1, every ``config.training.snapshot_freq`` steps
    and at the end of each epoch.  The process group is destroyed on return.

    NOTE(review): the model is used without a DistributedDataParallel wrapper
    in this function, so nothing here synchronises gradients between ranks and
    only local rank 0's weights are checkpointed — confirm this is intended.
    """
    device_id = setup_ddp_flexible()
    # The returned id is used both as the CUDA device index and for the
    # "rank == 0" evaluation / checkpoint guards below.
    rank = device_id

    # Shard the datasets according to the actual process group rather than a
    # hard-coded replica count; fall back to the historical values when no
    # process group is available.
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        sampler_rank = dist.get_rank()
    else:
        world_size = 8
        sampler_rank = rank

    self.device = rank
    self.config.device = device_id

    args, config = self.args, self.config
    tb_logger = self.config.tb_logger

    dataset, eval_dataset = get_dataset(args, config)

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=sampler_rank)

    train_loader = data.DataLoader(
        dataset,
        batch_size=config.training.batch_size,
        sampler=sampler,
        num_workers=config.data.num_workers,
    )

    self.use_sepearate_test_data = self.config.invdyn.use_sepearate_test_data
    self.seperate_test_data_fn = self.config.invdyn.seperate_test_data_fn

    eval_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=sampler_rank)
    eval_loader = data.DataLoader(
        eval_dataset,
        batch_size=config.training.batch_size,
        sampler=eval_sampler,
        num_workers=config.data.num_workers,
    )

    if self.use_sepearate_test_data:
        test_dataset = ControlSeqWorldModel(
            self.seperate_test_data_fn, 'qpos_invdyn', history_length=self.config.invdyn.history_length, future_length=self.config.invdyn.future_length, res=2, config=config, split=None
        )
        test_sampler = DistributedSampler(test_dataset, num_replicas=world_size, rank=sampler_rank)
        test_loader = data.DataLoader(
            test_dataset,
            batch_size=config.training.batch_size,
            sampler=test_sampler,
            num_workers=config.data.num_workers,
        )

    model = WorldModel(self.config)

    model = model.to(self.device)
    optimizer = get_optimizer(self.config, model.parameters())

    if self.config.model.ema:
        ema_helper = EMAHelper(mu=self.config.model.ema_rate)
        ema_helper.register(model)
    else:
        ema_helper = None

    # Device string shared by every checkpoint load in this function.
    ckpt_map_location = f'cuda:{self.device}'

    # Optionally warm-start from a pretrained world model.
    print(f"load_pretrained_wm_ckpt: {self.config.invdyn.load_pretrained_wm_ckpt}")
    if len(self.config.invdyn.load_pretrained_wm_ckpt) > 0 and os.path.exists(self.config.invdyn.load_pretrained_wm_ckpt):
        print(f"Loading world model from {self.config.invdyn.load_pretrained_wm_ckpt}")
        loaded_ckpt = torch.load(self.config.invdyn.load_pretrained_wm_ckpt, map_location=ckpt_map_location)
        model.load_state_dict(loaded_ckpt[0])
        model = model.to(self.device)
        if self.config.model.ema:
            ema_helper.load_state_dict(loaded_ckpt[-1])
            ema_helper.ema(model)

    if self.invdyn_train_residual_wm:
        print(f"Loading previous world model from {self.invdyn_prev_wm_ckpt}")
        # The previous model is built with hist_context_length = 0; the
        # configured value is restored once it has been constructed.
        prev_hist_context_length = self.config.invdyn.hist_context_length + 0
        self.config.invdyn.hist_context_length = 0
        prev_model = WorldModel(self.config)
        prev_model = prev_model.to(self.device)

        # Read the checkpoint once and pass a device string: a bare device
        # index is not a valid ``map_location`` for ``torch.load``.
        prev_ckpt = torch.load(self.invdyn_prev_wm_ckpt, map_location=ckpt_map_location)
        prev_model.load_state_dict(prev_ckpt[0])

        prev_ema_helper = EMAHelper(mu=self.config.model.ema_rate)
        prev_ema_helper.register(prev_model)
        prev_ema_helper.load_state_dict(prev_ckpt[-1])
        prev_ema_helper.ema(prev_model)

        self.config.invdyn.hist_context_length = prev_hist_context_length

        prev_model.eval()

    start_epoch, step = 0, 0

    logging_step_interval = self.config.training.logging_step_interval

    best_test_loss = 1e10

    for epoch in range(start_epoch, self.config.training.n_epochs):
        # Without this the sampler reuses the same permutation every epoch.
        sampler.set_epoch(epoch)

        data_start = time.time()
        data_time = 0

        loss_name_to_ep_loss = {'loss': [], 'task_loss': []}

        i = 0

        for data_batch in tqdm(train_loader):
            nex_state = data_batch['nex_state']

            data_time += time.time() - data_start
            model.train()

            nex_state = nex_state.to(self.device)

            input_dict = {
                key: data_batch[key].to(self.device) for key in data_batch
            }
            pred_nex_state = model(input_dict)

            if self.invdyn_train_residual_wm:
                # Regress only what the previous model fails to explain.
                with torch.no_grad():
                    prev_model.eval()
                    prev_pred_nex_state = prev_model(input_dict)

                    nex_state = nex_state - prev_pred_nex_state

            diff_acts = torch.sum((pred_nex_state - nex_state) ** 2, dim=-1)
            diff_acts = diff_acts.mean()

            loss = diff_acts

            loss_name_to_ep_loss['loss'].append(loss.detach().item())
            loss_name_to_ep_loss['task_loss'].append(diff_acts.detach().item())

            step += 1

            optimizer.zero_grad()
            loss.backward()

            # Best-effort clipping: a failure here (e.g. a config without
            # ``optim.grad_clip``) must not abort training.
            try:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), config.optim.grad_clip
                )
            except Exception:
                pass
            optimizer.step()

            if self.config.model.ema:
                ema_helper.update(model)

            if step % logging_step_interval == 0 and rank == 0:
                with torch.no_grad():
                    model.eval()

                    ###### EVAL MODEL ######
                    # NOTE(review): the eval target is the raw ``nex_state``
                    # even when training on residuals — confirm.
                    tot_test_loss = []
                    for test_data_batch in tqdm(eval_loader):
                        nex_gt_state = test_data_batch['nex_state'].to(self.device)
                        test_input_dict = {
                            key: test_data_batch[key].to(self.device) for key in test_data_batch
                        }
                        test_pred_nex_state = model(test_input_dict)

                        test_diff_acts = torch.sum((test_pred_nex_state - nex_gt_state) ** 2, dim=-1)
                        test_diff_acts = test_diff_acts.mean()
                        tot_test_loss.append(test_diff_acts.detach().item())
                    tot_test_loss = sum(tot_test_loss) / float(len(tot_test_loss))

                    if tot_test_loss < best_test_loss:
                        best_test_loss = tot_test_loss
                        states = [
                            model.state_dict(),
                            optimizer.state_dict(),
                            epoch,
                            step,
                        ]
                        if self.normalize_input:
                            states.append(self.running_mean_std.state_dict())

                        if self.config.model.ema:
                            states.append(ema_helper.state_dict())

                        torch.save(states, os.path.join(self.args.log_path, "ckpt_eval_best.pth"))
                    ###### EVAL MODEL ######

                    if self.use_sepearate_test_data:
                        ###### TEST MODEL ######
                        tot_real_test_loss = []
                        for test_data_batch in tqdm(test_loader):
                            nex_gt_state = test_data_batch['nex_state'].to(self.device)
                            test_input_dict = {
                                key: test_data_batch[key].to(self.device) for key in test_data_batch
                            }
                            test_pred_nex_state = model(test_input_dict)

                            test_diff_acts = torch.sum((test_pred_nex_state - nex_gt_state) ** 2, dim=-1)
                            test_diff_acts = test_diff_acts.mean()
                            tot_real_test_loss.append(test_diff_acts.detach().item())
                        tot_real_test_loss = sum(tot_real_test_loss) / float(len(tot_real_test_loss))
                    else:
                        tot_real_test_loss = tot_test_loss
                    model.train()
                    ###### TEST MODEL ######

                avg_ep_loss = sum(loss_name_to_ep_loss['loss']) / float(len(loss_name_to_ep_loss['loss']))

                tb_logger.add_scalar("loss", avg_ep_loss, global_step=step)

                logging_info_str = f"step: {step}, loss: {avg_ep_loss}, "
                for key in loss_name_to_ep_loss:
                    if len(loss_name_to_ep_loss[key]) > 0:
                        avg_ep_loss = sum(loss_name_to_ep_loss[key]) / float(len(loss_name_to_ep_loss[key]))
                        logging_info_str += f"{key}: {avg_ep_loss}, "
                logging_info_str += f"eval_loss: {tot_test_loss}, test_loss: {tot_real_test_loss}, data time: {data_time / (i+1)}"

                logging.info(logging_info_str)

            # The parentheses matter: ``and`` binds tighter than ``or``, so
            # without them every rank would write ckpt.pth on snapshot steps.
            if (step % self.config.training.snapshot_freq == 0 or step == 1) and rank == 0:
                states = [
                    model.state_dict(),
                    optimizer.state_dict(),
                    epoch,
                    step,
                ]

                if self.normalize_input:
                    states.append(self.running_mean_std.state_dict())

                if self.config.model.ema:
                    states.append(ema_helper.state_dict())

                torch.save(states, os.path.join(self.args.log_path, "ckpt.pth"))

            data_start = time.time()

            i += 1

        if rank == 0:
            # End-of-epoch checkpoint: [model, optimizer, epoch, step,
            # (input normaliser), (EMA state)].
            states = [
                model.state_dict(),
                optimizer.state_dict(),
                epoch,
                step,
            ]

            if self.normalize_input:
                states.append(self.running_mean_std.state_dict())

            if self.config.model.ema:
                states.append(ema_helper.state_dict())

            torch.save(states, os.path.join(self.args.log_path, "ckpt.pth"))

    # Tear down the process group created by setup_ddp_flexible().
    dist.destroy_process_group()

